#!/usr/bin/env python3
# H21 — Atomic Plateaus & Rims
# CONTROL: present-act, boolean/ordinal; deterministic DDA 1/r per shell; NO curves/weights; NO RNG in control.
# GEOMETRY: one center; a set of radial "bands" (plateaus). The outer band optionally has azimuthal segments
#           (periodic keep pattern across sectors); the inner band is isotropic.
# READOUTS (diagnostics-only):
#   • Radial plateaus: equal-Δr bin profile; per-band CV small; gap amplitude small.
#   • Rim segments: on the designated band, strong K-fold periodic envelope (R^2_K high; visibility high).
#   • Inner-band azimuth flatness small (no spurious anisotropy).
#
# PASS when all per-band CV/gap/segment tests are green. Control remains theory-pure.

import argparse, csv, hashlib, json, math, os, sys
from datetime import datetime, timezone
from typing import Dict, List, Tuple

# ---------- utils ----------
def utc_ts() -> str:
    """Return the current UTC time as ``YYYY-MM-DDTHH-MM-SSZ``.

    The time fields are separated by hyphens rather than colons.
    """
    now = datetime.now(tz=timezone.utc)
    stamp = now.strftime("%Y-%m-%dT%H-%M-%SZ")
    return stamp

def ensure_dirs(root:str, subs:List[str]) -> None:
    """Create every sub-directory in ``subs`` underneath ``root``.

    Intermediate directories are created as needed; directories that
    already exist are left untouched.
    """
    for rel in subs:
        target = os.path.join(root, rel)
        os.makedirs(target, exist_ok=True)

def wtxt(path:str, txt:str) -> None:
    """Write ``txt`` to ``path`` as UTF-8, replacing any existing content."""
    with open(path, "w", encoding="utf-8") as handle:
        handle.write(txt)

def jdump(path:str, obj:dict) -> None:
    """Serialize ``obj`` as JSON into ``path`` (UTF-8).

    Output uses a 2-space indent and alphabetically sorted keys.
    """
    with open(path, "w", encoding="utf-8") as handle:
        json.dump(
            obj,
            handle,
            sort_keys=True,
            indent=2,
        )

def sha256_file(path:str) -> str:
    """Return the hexadecimal SHA-256 digest of the file at ``path``.

    The file is read in binary mode in 1 MiB chunks, so the whole file
    is never held in memory at once.
    """
    chunk_size = 1 << 20
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        # read() returns b"" at end of file, which ends the loop.
        while chunk := handle.read(chunk_size):
            digest.update(chunk)
    return digest.hexdigest()

def isqrt(n:int) -> int:
    """Return the integer square root (floor of the exact root) of ``n``.

    Thin wrapper around ``math.isqrt``.
    """
    root = math.isqrt(n)
    return int(root)

def modS(x:int, S:int) -> int:
    """Reduce ``x`` modulo ``S``; if the remainder is negative, add ``S`` once."""
    rem = x % S
    if rem < 0:
        rem += S
    return rem

# ---------- sector & masks ----------
def sector_index(x:int, y:int, cx:int, cy:int, S:int) -> int:
    """Map cell ``(x, y)`` to one of ``S`` equal angular sectors around ``(cx, cy)``.

    The angle is taken with ``atan2(y - cy, x - cx)``, shifted into
    ``[0, 2*pi)`` and scaled to a sector number. The result is clamped
    to ``S - 1`` in case rounding lands exactly on ``S``.
    """
    full_turn = 2.0*math.pi
    theta = math.atan2(y - cy, x - cx)
    if theta < 0:
        theta += full_turn
    idx = int((theta/full_turn)*S)
    return min(idx, S-1)

def band_allows(r:int, s:int, band:dict, S:int) -> bool:
    """Return True if shell ``r`` / sector ``s`` is kept by the band mask.

    ``band`` keys read here:
      * ``r_min`` / ``r_max`` -- inclusive radial extent of the band.
      * ``seg_period`` (default 0) -- number K of kept segments around the
        ring; a value <= 0 means the band is isotropic (all sectors kept).
      * ``seg_half`` (default 1) -- half-width of each segment, in sectors.
      * ``seg_offset`` (default 0) -- sector of the first segment centre.

    For a segmented band the K centres sit at ``offset + round(m*S/K)``
    (mod ``S``) for ``m = 0..K-1``, and sector ``s`` is kept when its
    circular distance to any centre is at most ``seg_half``.

    The centres are rounded to the nearest sector rather than stepped by
    ``S // K``. A truncated step accumulates drift when K does not divide
    ``S`` and leaves an uneven gap at the wrap-around. When K divides
    ``S`` both schemes give the same centres.

    ``S`` is expected to be a positive sector count.
    """
    r_lo, r_hi = int(band["r_min"]), int(band["r_max"])
    if not (r_lo <= r <= r_hi):
        return False

    k = int(band.get("seg_period", 0))
    if k <= 0:
        # Isotropic band: every sector inside the radial range is kept.
        return True

    seg_half = int(band.get("seg_half", 1))
    offset = int(band.get("seg_offset", 0))

    for m in range(k):
        # Integer round-half-up of m*S/k; keeps the mask free of floats.
        center = (offset + (2*m*S + k)//(2*k)) % S
        delta = abs(s - center)
        # Circular distance on a ring of S sectors.
        if min(delta, S - delta) <= seg_half:
            return True
    return False

# ---------- geometry & per-shell counts ----------
def build_counts(N:int, cx:int, cy:int, S:int, bands:List[dict]) -> Tuple[Dict[int,int], Dict[int,List[int]], int]:
    """
    Tally the grid cells admitted by the band masks.

    Every cell of the N x N grid is assigned an integer shell radius
    (floor of its distance to ``(cx, cy)``) and a sector index; a cell is
    counted when at least one band in ``bands`` allows that (shell, sector).

    Returns ``(shell_counts, shell_sector_counts, R_edge)`` where
      * ``shell_counts[r]`` is the number of admitted cells on shell r,
      * ``shell_sector_counts[r][s]`` is the same split over the S sectors,
      * ``R_edge`` is the distance from the centre to the nearest grid edge.
    Only shells with at least one admitted cell appear in the dicts.
    """
    per_shell: Dict[int,int] = {}
    per_shell_sector: Dict[int,List[int]] = {}
    for y in range(N):
        dy = y - cy
        for x in range(N):
            dx = x - cx
            radius = isqrt(dx*dx + dy*dy)
            sector = sector_index(x, y, cx, cy, S)
            # any() stops at the first band that admits the cell.
            if not any(band_allows(radius, sector, band, S) for band in bands):
                continue
            per_shell[radius] = per_shell.get(radius, 0) + 1
            row = per_shell_sector.get(radius)
            if row is None:
                row = [0]*S
                per_shell_sector[radius] = row
            row[sector] += 1
    edge_radius = min(cx, cy, (N-1)-cx, (N-1)-cy)
    return per_shell, per_shell_sector, edge_radius

# ---------- present-act control: DDA 1/r per shell ----------
def simulate_dda(shell_counts: Dict[int,int], H:int, rate_num:int) -> Dict[int,int]:
    """Run an integer accumulator (DDA) for ``H`` ticks on every shell.

    On each tick a shell's accumulator grows by ``rate_num``. When it
    reaches the shell radius the shell fires once and the radius is
    subtracted. Shell 0 (the centre) never fires.

    Returns a dict with the same keys as ``shell_counts`` mapping each
    shell radius to its number of fires. Only the keys of
    ``shell_counts`` are used.
    """
    fires: Dict[int,int] = {}
    # Shells do not interact, so each can be stepped through all ticks in turn.
    for radius in shell_counts:
        fired = 0
        if radius != 0:
            acc = 0
            for _tick in range(H):
                acc += rate_num
                if acc >= radius:
                    acc -= radius
                    fired += 1
        fires[radius] = fired
    return fires

# ---------- diagnostics ----------
def build_linear_bins(r_min:int, r_max:int, width:int) -> List[Tuple[int,int]]:
    """Split ``r_min..r_max`` (inclusive) into consecutive equal-width bins.

    Each bin is an inclusive ``(lo, hi)`` pair covering exactly ``width``
    shells. Bins start at ``r_min``; a trailing remainder shorter than
    ``width`` is dropped. An empty list is returned when not even one
    full bin fits, including when ``r_max < r_min``.

    Raises:
        ValueError: if ``width`` is not positive. A zero width cannot
            define a bin, and a negative width would never terminate.
    """
    if width <= 0:
        raise ValueError(f"bin width must be a positive integer, got {width}")
    # Number of complete bins that fit in the inclusive range.
    n_bins = (r_max - r_min + 1) // width
    return [(r_min + i*width, r_min + (i + 1)*width - 1) for i in range(max(0, n_bins))]

def radial_profile(shell_counts: Dict[int,int], fires: Dict[int,int], H:int,
                   bins: List[Tuple[int,int]]) -> List[float]:
    """Return one per-tick total for each radial bin.

    For every inclusive ``(lo, hi)`` bin the result is the sum, over the
    shells in the bin, of ``cell count * (fires / H)``. Shells missing
    from either dict contribute zero.
    """
    def bin_total(first:int, last:int) -> float:
        # Plain left-to-right accumulation over the shells of one bin.
        acc = 0.0
        for radius in range(first, last+1):
            acc += shell_counts.get(radius, 0) * (fires.get(radius, 0)/H)
        return acc

    return [bin_total(lo, hi) for lo, hi in bins]

def cv(arr: List[float]) -> float:
    """Return the coefficient of variation (population std / mean) of ``arr``.

    Returns NaN for an empty input and +inf when the mean is exactly zero.
    """
    if not arr:
        return float("nan")
    count = len(arr)
    center = sum(arr)/count
    if center == 0.0:
        return float("inf")
    variance = sum((value-center)*(value-center) for value in arr)/count
    return math.sqrt(variance)/center

def mean(arr: List[float]) -> float:
    """Return the arithmetic mean of ``arr``, or NaN when ``arr`` is empty."""
    if arr:
        return sum(arr)/len(arr)
    return float("nan")

def sector_profile(shell_sector_counts: Dict[int,List[int]], fires: Dict[int,int], H:int,
                   r0:int, r1:int, S:int) -> List[float]:
    """Return per-sector totals accumulated over shells ``r0..r1`` (inclusive).

    Each shell adds ``sector cell count * (fires / H)`` to every sector.
    Shells absent from ``shell_sector_counts`` are skipped, and a shell
    missing from ``fires`` contributes zero. The result has length ``S``.
    """
    totals = [0.0]*S
    for radius in range(r0, r1+1):
        row = shell_sector_counts.get(radius)
        if row is None:
            continue
        per_tick = fires.get(radius, 0)/H
        totals = [running + row[idx]*per_tick for idx, running in enumerate(totals)]
    return totals

def r2_kfold(profile: List[float], K:int) -> float:
    """Return R^2 of a single K-fold harmonic fitted to the sector profile.

    The model is ``mean + a*cos(K*theta) + b*sin(K*theta)``, where the
    sectors are sampled at ``theta = 2*pi*i/S``. The coefficients are
    the projections of the mean-removed profile onto the two harmonics.

    Returns NaN for an empty profile or ``K <= 0``. When the profile has
    no variance the result is 1.0.
    """
    S = len(profile)
    if S == 0 or K <= 0:
        return float("nan")

    mu = sum(profile)/S
    cos_terms: List[float] = []
    sin_terms: List[float] = []
    for x in range(S):
        phase = 2.0*math.pi*K*x/S
        cos_terms.append(math.cos(phase))
        sin_terms.append(math.sin(phase))

    # Project the mean-removed profile onto the K-th harmonic pair.
    centered = [value - mu for value in profile]
    proj_cos = sum(dev*cs for dev, cs in zip(centered, cos_terms))
    proj_sin = sum(dev*sn for dev, sn in zip(centered, sin_terms))

    half_n = S/2.0
    fitted = [mu + (proj_cos*cs + proj_sin*sn)/half_n
              for cs, sn in zip(cos_terms, sin_terms)]

    ss_tot = sum(dev*dev for dev in centered)
    ss_res = sum((value - fit)*(value - fit) for value, fit in zip(profile, fitted))
    if ss_tot > 0:
        return 1.0 - ss_res/ss_tot
    return 1.0

def visibility(profile: List[float]) -> float:
    """Return the contrast ``(max - min)/(max + min)`` of the profile.

    Returns NaN for an empty profile and 0.0 when ``max + min`` is not
    positive.
    """
    if not profile:
        return float("nan")
    peak = max(profile)
    trough = min(profile)
    total = peak + trough
    if total > 0:
        return (peak - trough)/total
    return 0.0

# ---------- run single-center scene ----------
def run_scene(M:dict) -> dict:
    """Run one single-centre scene described by manifest ``M`` and audit it.

    Manifest keys read: ``grid`` (``N``, ``cx``, ``cy``), ``sectors`` (``S``),
    ``H``, ``rate_num``, ``bands``, ``slope``, ``plateau`` (``r_min``,
    ``shells_per_bin``, ``outer_frac``), ``acceptance`` and the optional
    ``outer_margin`` (default 24). ``slope`` and ``outer_frac`` are read
    but take no part in the computation.

    Checks performed (all must hold for ``pass``):
      * per-band CV of the binned radial profile <= ``band_cv_max``;
      * mean of the gap between the two innermost bands, relative to the
        innermost band's mean, <= ``gap_rel_max`` (only with >= 2 bands);
      * on the first band with ``seg_period > 0``: K-fold R^2 >=
        ``seg_r2_min`` and visibility >= ``seg_vis_min``;
      * azimuthal CV of the innermost band <= ``inner_az_flat_max``.

    Returns a dict with the bins, the radial profile, the individual
    metrics and flags, and the overall ``pass`` boolean.
    """
    # --- manifest fields ---
    N = int(M["grid"]["N"])
    cx = int(M["grid"]["cx"])
    cy = int(M["grid"]["cy"])
    S = int(M["sectors"]["S"])
    H = int(M["H"])
    rate_num = int(M["rate_num"])

    bands = M["bands"]
    slope_cfg = M["slope"]  # not used below; the lookup still requires the key
    plateau = M["plateau"]
    limits = M["acceptance"]

    # --- geometry counts and the DDA control ---
    shell_counts, shell_sector_counts, R_edge = build_counts(N, cx, cy, S, bands)
    fires = simulate_dda(shell_counts, H, rate_num)

    # --- shared radial window and binned profile ---
    outer_margin = int(M.get("outer_margin", 24))
    r_max_glob = R_edge - outer_margin
    r_min_bins = int(plateau["r_min"])
    shells_per_bin = int(plateau["shells_per_bin"])
    outer_frac = float(plateau["outer_frac"])  # not used below; parsed as before
    bins = build_linear_bins(r_min_bins, r_max_glob, shells_per_bin)
    prof = radial_profile(shell_counts, fires, H, bins)

    binned = list(zip(bins, prof))

    def values_inside(lo_r:int, hi_r:int) -> List[float]:
        # Profile values of the bins lying entirely within lo_r..hi_r.
        return [v for (lo, hi), v in binned if lo >= lo_r and hi <= hi_r]

    # --- per-band plateau statistics, in declared band order ---
    band_metrics = []
    for band in bands:
        b_lo, b_hi = int(band["r_min"]), int(band["r_max"])
        inside = values_inside(b_lo, b_hi)
        band_metrics.append({"r_min": b_lo, "r_max": b_hi,
                             "cv": cv(inside), "mean": mean(inside)})

    # --- gap between the two innermost bands (sorted by r_min) ---
    gap_ok, gap_ratio = True, None
    if len(bands) >= 2:
        by_radius = sorted(bands, key=lambda b: int(b["r_min"]))
        inner_band, next_band = by_radius[0], by_radius[1]
        gap_vals = values_inside(int(inner_band["r_max"]) + 1,
                                 int(next_band["r_min"]) - 1)
        mu_gap = mean(gap_vals) if gap_vals else 0.0
        # Normalise by the innermost band's mean (matched on r_min).
        inner_r_min = int(inner_band["r_min"])
        mu_inner = next(bm["mean"] for bm in band_metrics if bm["r_min"] == inner_r_min)
        gap_ratio = (mu_gap / mu_inner) if mu_inner and mu_inner > 0 else 0.0
        gap_ok = (gap_ratio <= float(limits["gap_rel_max"]))

    # --- rim segmentation: first band declaring seg_period > 0 ---
    seg_ok, seg_metrics = True, None
    segmented = [b for b in bands if int(b.get("seg_period", 0)) > 0]
    if segmented:
        rim = segmented[0]
        rim_lo, rim_hi = int(rim["r_min"]), int(rim["r_max"])
        K = int(rim["seg_period"])
        rim_prof = sector_profile(shell_sector_counts, fires, H, rim_lo, rim_hi, S)
        r2K = r2_kfold(rim_prof, K)
        vis = visibility(rim_prof)
        seg_metrics = {"r2K": r2K, "visibility": vis, "K": K}
        seg_ok = (r2K >= float(limits["seg_r2_min"])) and (vis >= float(limits["seg_vis_min"]))

    # --- azimuthal flatness of the innermost band ---
    az_ok, az_flat = True, None
    if len(bands) >= 1:
        # min() returns the first band with the smallest r_min.
        innermost = min(bands, key=lambda b: int(b["r_min"]))
        in_lo, in_hi = int(innermost["r_min"]), int(innermost["r_max"])
        inner_prof = sector_profile(shell_sector_counts, fires, H, in_lo, in_hi, S)
        mu = sum(inner_prof)/len(inner_prof) if inner_prof else 0.0
        if mu == 0.0:
            az_flat = float("inf")
        else:
            s2 = sum((v-mu)*(v-mu) for v in inner_prof)/len(inner_prof)
            az_flat = math.sqrt(s2)/mu
        az_ok = (az_flat <= float(limits["inner_az_flat_max"]))

    # --- per-band CV acceptance (a NaN CV compares False and fails) ---
    cv_ok = all((bm["cv"] <= float(limits["band_cv_max"])) for bm in band_metrics)

    passed = bool(cv_ok and gap_ok and seg_ok and az_ok)

    audit = {
        "bins": bins, "radial_profile": prof,
        "band_metrics": band_metrics, "gap_ratio": gap_ratio, "gap_ok": gap_ok,
        "seg_metrics": seg_metrics, "seg_ok": seg_ok,
        "az_flat_inner": az_flat, "az_ok": az_ok,
        "pass": passed
    }
    return audit

def main():
    """Command-line entry: run the scene from a manifest and write the outputs.

    Under ``--outdir`` this creates the directory layout, copies the
    manifest, records a small environment log, writes the binned radial
    profile as CSV and the audit as JSON, and finally writes and prints a
    one-line result summary.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--manifest", required=True)
    parser.add_argument("--outdir", required=True)
    args = parser.parse_args()

    root = os.path.abspath(args.outdir)
    ensure_dirs(root, ["config","outputs/metrics","outputs/audits","outputs/run_info","logs"])

    # Load the manifest and keep a normalised copy next to the outputs.
    with open(args.manifest, "r", encoding="utf-8") as fh:
        M = json.load(fh)
    jdump(os.path.join(root, "config", "manifest_h21.json"), M)

    # Environment snapshot.
    env_lines = [
        f"utc={utc_ts()}",
        f"os={os.name}",
        f"cwd={os.getcwd()}",
        f"python={sys.version.split()[0]}",
    ]
    wtxt(os.path.join(root, "logs", "env.txt"), "\n".join(env_lines))

    aud = run_scene(M)

    # Radial profile, one row per bin.
    csv_path = os.path.join(root, "outputs/metrics", "h21_radial_profile.csv")
    with open(csv_path, "w", newline="", encoding="utf-8") as fh:
        writer = csv.writer(fh)
        writer.writerow(["bin_lo","bin_hi","value"])
        writer.writerows([lo, hi, f"{v:.6f}"]
                         for (lo, hi), v in zip(aud["bins"], aud["radial_profile"]))

    # Full audit.
    jdump(os.path.join(root, "outputs/audits", "h21_audit.json"), aud)

    # Key numbers for the summary line; anything unavailable is shown as nan.
    nan = float("nan")
    metrics = aud["band_metrics"]
    inner_cv = metrics[0]["cv"] if metrics else nan
    outer_cv = metrics[1]["cv"] if len(metrics) > 1 else nan
    seg = aud["seg_metrics"]
    seg_r2 = seg["r2K"] if seg else nan
    seg_vis = seg["visibility"] if seg else nan
    gap_ratio = aud["gap_ratio"] if aud["gap_ratio"] is not None else nan
    az_flat = aud["az_flat_inner"] if aud["az_flat_inner"] is not None else nan

    result = ("H21 PASS={p} cv_inner={ci:.4f} cv_outer={co:.4f} gap_ratio={gr:.3f} "
              "seg_r2={r2:.3f} seg_vis={vs:.3f} az_flat_inner={az:.3f}"
              .format(p=aud["pass"], ci=inner_cv, co=outer_cv, gr=gap_ratio,
                      r2=seg_r2, vs=seg_vis, az=az_flat))
    wtxt(os.path.join(root, "outputs/run_info", "result_line.txt"), result)
    print(result)

# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
